import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

class ODE1d:
    '''
    Stochastic 1-d dynamics for the "x player".

    Each call to ``update_x`` moves the scalar ``x`` either by one stochastic
    Kutta 3rd-order step along ``-x_speed * grad f`` or, when
    ``x_speed == "best response"``, by numerically minimizing the sampled
    loss.  The loss uses "positive" labels drawn from ``g0`` and "negative"
    labels drawn from the weights ``rho_bar`` passed to ``update_x``; both are
    sampled from the support points ``z_i``.

    Parameters
    ----------
    z_i : array_like
        Support points that labels are sampled from.
    dt : float
        Time step.
    nT : int
        Number of time steps.  Sizes ``x_history`` (when ``save_data``) and
        ``loss`` (``nT - 1`` entries, written at index ``t - 1``).
    g0 : array_like
        Unnormalized weights on ``z_i`` for the positive labels.
    x0 : float or array_like
        Initial value of x.
    nSamples : int
        Number of labels drawn per loss/gradient evaluation.
    save_data : bool
        If True, store x at every step; otherwise only at steps
        ``floor(nT/5)`` and ``floor(nT/3)``.
    x_speed : float or str
        Multiplier on the gradient flow, or the string ``"best response"``.
    '''

    def __init__(self,z_i,dt,nT,g0,x0,nSamples,save_data=True,x_speed=1.e0):
        self.z_i             = z_i
        self.dt              = dt
        self.g0              = g0/np.sum(g0) # normalized sampling weights for the positive labels
        self.x               = x0
        self.x0              = np.copy(x0)
        if save_data:
            self.x_history   = np.zeros(nT)
        else:
            self.x_history   = np.zeros(2) # x is saved at 2 time steps only
            self.saveat      = [np.floor(nT/5.),np.floor(nT/3.)]
        self.save_data       = save_data
        self.nSamples        = nSamples
        self.loss            = np.zeros(nT-1) # loss for x player, written at index t-1
        self.x_speed         = x_speed

    def update_x(self,rho_bar,t):
        '''
        Advance x by one time step and return the new value.

        rho_bar : weights over z_i for the negative labels (normalized here).
        t       : time index; the loss goes to ``loss[t-1]`` and x is
                  recorded in ``x_history`` according to ``save_data``.
        '''
        if self._is_best_response():
            self._best_response_step(rho_bar,t)
        else:
            self._rk3_step(rho_bar,t)
        self._save_history(t)
        return self.x

    def _is_best_response(self):
        '''True when x_speed selects the best-response update.'''
        # isinstance guard: an array-valued x_speed compared to a str is not a plain bool
        return isinstance(self.x_speed,str) and self.x_speed=="best response"

    def _best_response_step(self,rho_bar,t):
        '''Set x to the minimizer of the loss for one freshly sampled label set.'''
        N_samples            = self.nSamples
        normalized_rho       = rho_bar/np.sum(rho_bar)
        # negative labels from rho, then positive labels from g0 (draw order kept for reproducibility)
        Z                    = np.random.choice(self.z_i,size=N_samples,p=normalized_rho)
        positive_labels      = np.random.choice(self.z_i,size=N_samples,p=self.g0)
        self.positive_labels = positive_labels
        self.negative_labels = Z

        # record the loss at the current iterate, before optimizing
        loss,_          = self._loss_and_grad(self.x,positive_labels,Z)
        self.loss[t-1]  = loss

        # objective and gradient are both evaluated at the optimizer's trial point y
        # and share the same (mean) scaling, so SLSQP sees a consistent (f, g) pair
        opt = minimize(lambda y: self._loss_and_grad(y,positive_labels,Z),
                       1.,jac=True,method="SLSQP")
        if opt.success:
            self.x = opt.x[0]
        else:
            print("x not optimized")
            print(opt.message)

    def _rk3_step(self,rho_bar,t):
        '''One stochastic Kutta 3rd-order step of x_dot = -x_speed * grad f.'''
        dt   = self.dt
        eta_ = self.x_speed
        x    = self.x
        # _get_grad returns dt*grad, so the velocity is F(x) = -dt*eta_*grad; every
        # stage offset therefore carries the minus sign and eta_ (descent, not ascent).
        # NOTE(review): the dt inside _get_grad makes the effective step dt**2*eta_;
        # kept to preserve the existing time scaling — confirm it is intended.
        # Each stage resamples labels; loss[t-1] ends up holding the stage-3 value.
        k1 = self._get_grad(x,                                   rho_bar,t) # +0
        k2 = self._get_grad(x - dt/2.*eta_*k1,                   rho_bar,t) # +h/2
        k3 = self._get_grad(x + dt*eta_*k1 - 2.*dt*eta_*k2,      rho_bar,t) # +h
        # rebind (not `-=`) so an ndarray x0 passed by the caller is never mutated
        self.x = x - dt/6.*eta_*(k1 + 4.*k2 + k3)

    def _save_history(self,t):
        '''Store x in x_history: every step, or only at the saveat steps.'''
        if self.save_data:
            self.x_history[t] = np.copy(self.x)
            return
        # independent checks so both slots are filled even if the two steps coincide
        for slot,step in enumerate(self.saveat):
            if t == step:
                self.x_history[slot] = np.copy(self.x)

    def _get_grad(self,current_x,rho,t):
        '''
        Sample fresh labels and return dt * grad f(current_x).

        Negative labels are drawn from rho (normalized here) and positive
        labels from g0; the positive labels are stored on
        ``self.positive_labels``.  Does not modify ``self.x``.
        '''
        N_samples       = self.nSamples
        normalized_rho  = rho/np.sum(rho)
        Zt              = np.random.choice(self.z_i,size=N_samples,p=normalized_rho)
        positive_labels = np.random.choice(self.z_i,size=N_samples,p=self.g0)
        self.positive_labels = positive_labels
        return self.dt*self._grad_f(current_x,Zt,t)

    def _grad_f(self,x,Z,t):
        '''
        Return d/dx of the sampled loss at x.

        Uses ``self.positive_labels`` and the negative labels Z; stores Z on
        ``self.negative_labels`` and the mean loss at x in ``self.loss[t-1]``.
        '''
        self.negative_labels = Z # saved so that the loss can be computed in the PDE object
        loss,grad = self._loss_and_grad(x,self.positive_labels,Z)
        self.loss[t-1] = loss
        return grad

    @staticmethod
    def _loss_and_grad(x,positive_labels,negative_labels):
        '''
        Return (loss, d loss/dx) for the sampled labels at x.

        loss = [ sum_p log(1+exp(-3(p-x))) + sum_z log(1+exp(-3(z-x)))
                 + 3 sum_z (z-x) ] / len(positive_labels)
        '''
        n_pos = len(positive_labels)
        n_neg = len(negative_labels)
        # logaddexp(0, a) == log(1 + exp(a)) without overflowing for large a
        loss = (np.sum(np.logaddexp(0.,-3.*(positive_labels-x)))
                + np.sum(np.logaddexp(0.,-3.*(negative_labels-x)))
                + 3.*np.sum(negative_labels-x))/n_pos
        # 1/(1+exp(a)) written as exp(-logaddexp(0, a)) to avoid exp overflow
        grad = (3.*np.sum(np.exp(-np.logaddexp(0.,3.*(positive_labels-x))))/n_pos
                + 3.*np.sum(np.exp(-np.logaddexp(0.,3.*(negative_labels-x))))/n_neg
                - 3.)
        return loss,grad